ScatterElements
返回一个新tensor,根据指定索引和更新值对input中的元素进行指定操作(替换或相加)。不支持隐式类型转换。举例:一个三维输入tensor的返回为:
output[indices[i][j][k]][j][k] = updates[i][j][k] #if axis == 0, reduction == "none"
output[i][indices[i][j][k]][k] += updates[i][j][k] #if axis == 1, reduction == "add"
output[i][j][indices[i][j][k]] = updates[i][j][k] #if axis == 2, reduction == "none"
- 输入:
input - 输入数据的地址
indices - 指定索引。
updates - 更新值。
param - 算子计算所需参数的结构体。其各成员见下述。
core_mask - 核掩码。
ScatterElementsParameter定义:
/**
 * Parameter block consumed by the ScatterElements kernels.
 *
 * The two stride pointers must reference caller-provided storage; the
 * remaining fields describe the tensors and the requested operation.
 */
typedef struct ScatterElementsParameter {
    int *indices_stride_;   /* stride of each dimension of the indices array */
    int *output_stride_;    /* stride of each dimension of the output array */
    int input_dims_;        /* number of dimensions of the input tensor */
    int axis_;              /* axis along which the indices apply */
    int input_axis_size_;   /* number of elements along that axis */
    int indices_total_num_; /* total number of elements in the indices array */
    int input_total_num_;   /* total number of elements in the input array */
    int reduction_type_;    /* reduction type: 0 = none, 1 = add */
} ScatterElementsParameter;
- 输出:
output - 输出地址。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持int8, int16, int32, fp32, fp64, cplx64, cplx128
MT7004 支持fp16, fp32, int16, int32, cplx64
如果 indices 中有多个索引向量对应于同一位置,则输出中该位置值是不确定的。
如果 indices 的值超出 input 在 axis 轴上的索引上下界,则相应的 updates 不会写入 output,也不会抛出索引错误。
共享存储版本:
-
void i8_scatter_elements_s(int8_t *input, int8_t *output, int *indices, int8_t *updates, ScatterElementsParameter *param, int core_mask)
-
void i16_scatter_elements_s(int16_t *input, int16_t *output, int *indices, int16_t *updates, ScatterElementsParameter *param, int core_mask)
-
void i32_scatter_elements_s(int *input, int *output, int *indices, int *updates, ScatterElementsParameter *param, int core_mask)
-
void hp_scatter_elements_s(half *input, half *output, int *indices, half *updates, ScatterElementsParameter *param, int core_mask)
-
void fp_scatter_elements_s(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
-
void dp_scatter_elements_s(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
-
void c64_scatter_elements_s(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
-
void c128_scatter_elements_s(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
C调用示例:
1void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) {
2 param->indices_stride_[param->input_dims_ - 1] = 1;
3 int i;
4 for (i = param->input_dims_ - 1; i > 0; --i) {
5 param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i];
6 }
7 param->output_stride_[param->input_dims_ - 1] = 1;
8 for (i = param->input_dims_ - 1; i > 0; --i) {
9 param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i];
10 }
11 param->indices_total_num_ = 1;
12 for (i = 0; i < param->input_dims_; i++) {
13 param->indices_total_num_ *= indices_shape[i];
14 }
15 param->input_total_num_ = 1;
16 for (i = 0; i < param->input_dims_; i++) {
17 param->input_total_num_ *= input_shape[i];
18 }
19 param->input_axis_size_ = input_shape[param->axis_];
20}
21
22void TestScatterElementsSMC(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) {
23 int core_num = GetCoreNum(core_mask);
24 int core_id = get_core_id();
25 int logic_core_id = GetLogicCoreId(core_mask, core_id);
26 void* input_data = (void*)0x88000000;
27 void* output_data = (void*)0x98000000;
28 int* indices_data = (int*)0xA8000000;
29 void* updates_data = (void*)0xB8000000;
30 ScatterElementsParameter* param = (ScatterElementsParameter*)0xC8000000;
31 if (logic_core_id == 0) {
32 param->axis_ = axis;
33 param->input_dims_ = ndim;
34 param->indices_stride_ = (int*)0xC8020000;
35 param->output_stride_ = (int*)0xC8040000;
36 param->reduction_type_ = reduction_type;
37 PackParam(param, indices_shape, input_shape);
38 }
39 sys_bar(0, core_num); // 初始化参数完成后进行同步
40 fp_scatter_elements_s(input_data, output_data, indices_data, updates_data, param, core_mask);
41}
42
/*
 * Entry point of the shared-storage example: an 8x30 input, a 3x3 indices
 * tensor, scatter along axis 0 with reduction type 0 (none), and a core mask
 * with the low four bits set.
 */
void main() {
    int input_shape[2] = {8, 30};
    int indices_shape[2] = {3, 3};
    int ndim = 2;
    int axis = 0;
    int reduction_type = 0;
    int core_mask = 0b1111;

    TestScatterElementsSMC(input_shape, indices_shape, ndim, axis, reduction_type, core_mask);
}
私有存储版本:
-
void i8_scatter_elements_p(int8_t *input, int8_t *output, int *indices, int8_t *updates, ScatterElementsParameter *param, int core_mask)
-
void i16_scatter_elements_p(int16_t *input, int16_t *output, int *indices, int16_t *updates, ScatterElementsParameter *param, int core_mask)
-
void i32_scatter_elements_p(int *input, int *output, int *indices, int *updates, ScatterElementsParameter *param, int core_mask)
-
void hp_scatter_elements_p(half *input, half *output, int *indices, half *updates, ScatterElementsParameter *param, int core_mask)
-
void fp_scatter_elements_p(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
-
void dp_scatter_elements_p(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
-
void c64_scatter_elements_p(float *input, float *output, int *indices, float *updates, ScatterElementsParameter *param, int core_mask)
-
void c128_scatter_elements_p(double *input, double *output, int *indices, double *updates, ScatterElementsParameter *param, int core_mask)
C调用示例:
1void PackParam(ScatterElementsParameter* param, int* indices_shape, int* input_shape) {
2 param->indices_stride_[param->input_dims_ - 1] = 1;
3 int i;
4 for (i = param->input_dims_ - 1; i > 0; --i) {
5 param->indices_stride_[i - 1] = param->indices_stride_[i] * indices_shape[i];
6 }
7 param->output_stride_[param->input_dims_ - 1] = 1;
8 for (i = param->input_dims_ - 1; i > 0; --i) {
9 param->output_stride_[i - 1] = param->output_stride_[i] * input_shape[i];
10 }
11 param->indices_total_num_ = 1;
12 for (i = 0; i < param->input_dims_; i++) {
13 param->indices_total_num_ *= indices_shape[i];
14 }
15 param->input_total_num_ = 1;
16 for (i = 0; i < param->input_dims_; i++) {
17 param->input_total_num_ *= input_shape[i];
18 }
19 param->input_axis_size_ = input_shape[param->axis_];
20}
21
22void TestScatterElementsL2(int* input_shape, int* indices_shape, int ndim, int axis, int reduction_type, int core_mask) {
23 void* input_data = (void*)0x10000000; // 私有存储版本地址设置在AM内
24 void* output_data = (void*)0x10001000;
25 int* indices_data = (int*)0x10002000;
26 void* updates_data = (void*)0x10003000;
27 ScatterElementsParameter* param = (ScatterElementsParameter*)0x10004000;
28 param->axis_ = axis;
29 param->input_dims_ = ndim;
30 param->indices_stride_ = (int*)0x10005000;
31 param->output_stride_ = (int*)0x10006000;
32 param->reduction_type_ = reduction_type;
33 PackParam(param, indices_shape, input_shape);
34 fp_scatter_elements_p(input_data, output_data, indices_data, updates_data, param, core_mask);
35}
36
/*
 * Entry point of the private-storage example: an 8x30 input, a 3x3 indices
 * tensor, scatter along axis 0 with reduction type 0 (none).
 */
void main() {
    int input_shape[2] = {8, 30};
    int indices_shape[2] = {3, 3};
    int ndim = 2;
    int axis = 0;
    int reduction_type = 0;
    int core_mask = 0b0001; /* the private-storage version may only be launched on a single core */

    TestScatterElementsL2(input_shape, indices_shape, ndim, axis, reduction_type, core_mask);
}